// The libMesh Finite Element Library.
// Copyright (C) 2002-2025 Benjamin S. Kirk, John W. Peterson, Roy H. Stogner

// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.

// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.

// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA



// Local includes
#include "libmesh/elem.h"
#include "libmesh/fe.h"
#include "libmesh/fe_interface.h"
#include "libmesh/fe_macro.h"
#include "libmesh/libmesh_logging.h"
#include "libmesh/quadrature.h"
#include "libmesh/tensor_value.h"
#include "libmesh/enum_elem_type.h"
#include "libmesh/quadrature_gauss.h"
#include "libmesh/libmesh_singleton.h"

namespace {

// Emits the "dual calculations are only verified for LAGRANGE" warning.
//
// This is deliberately a plain (non-template) function: keeping the
// warning macro in exactly one place means the unit tests see a single
// warning instead of one per FE specialization that gets instantiated.
void nonlagrange_dual_warning ()
{
  libmesh_warning("dual calculations have only been verified for the LAGRANGE family");
}

} // anonymous namespace


namespace libMesh
{
// ------------------------------------------------------------
// Whether we cache the node locations, edge and face orientations of the last
// element we computed on as needed to avoid calling init_shape_functions and
// compute_shape_functions
static const bool * caching = nullptr;

class CachingSetup: public Singleton::Setup
{
  private:
    void setup() { caching = new bool(!on_command_line("--disable-caching")); }
  public:
    ~CachingSetup() { delete caching; caching = nullptr; }
} caching_setup;


// ------------------------------------------------------------
// FE class members
// Constructor: forwards the dimension and FEType to the generic base
// class and marks the remembered side/edge element types as invalid.
template <unsigned int Dim, FEFamily T>
FE<Dim,T>::FE (const FEType & fet)
  : FEGenericBase<typename FEOutputType<T>::type> (Dim, fet),
    last_side (INVALID_ELEM),
    last_edge (INVALID_ELEM)
{
  // The family fixed by the template arguments has to agree with
  // the family stored in the FEType we were handed.
  libmesh_assert_equal_to (T, this->get_family());
}


// Returns the number of shape functions at the current order
// (base order plus the stored p level).  The count is queried from
// the current element when one is set, and from the bare element
// type otherwise.
template <unsigned int Dim, FEFamily T>
unsigned int FE<Dim,T>::n_shape_functions () const
{
  const auto total_order = this->fe_type.order + this->_p_level;

  // No element attached: fall back to the element type.
  if (!this->_elem)
    return this->n_dofs (this->get_type(), total_order);

  return this->n_dofs (this->_elem, total_order);
}


// Attaches the quadrature rule \p q (must be non-null) to this FE.
template <unsigned int Dim, FEFamily T>
void FE<Dim,T>::attach_quadrature_rule (QBase * q)
{
  libmesh_assert(q);

  // Forget the element and element type we last saw, so that nothing
  // computed with a previous quadrature rule is treated as reusable.
  this->_elem_type = INVALID_ELEM;
  this->_elem = nullptr;

  this->qrule = q;
}


// Fills \p di with the local indices of the node-associated degrees of
// freedom that live on side \p s of \p elem.  Local indices are
// assigned node by node, in node order; when \p add_p_level is true
// the element's p level is added to the order \p o.
template <unsigned int Dim, FEFamily T>
void FE<Dim,T>::dofs_on_side(const Elem * const elem,
                             const Order o,
                             unsigned int s,
                             std::vector<unsigned int> & di,
                             const bool add_p_level)
{
  libmesh_assert(elem);
  libmesh_assert_less (s, elem->n_sides());

  di.clear();

  const Order total_order =
    static_cast<Order>(o + add_p_level*elem->p_level());
  const unsigned int node_count = elem->n_nodes();

  // Index of the first dof belonging to the node being visited
  unsigned int first_dof = 0;

  for (unsigned int node = 0; node != node_count; ++node)
    {
      const unsigned int dofs_here =
        n_dofs_at_node(*elem, total_order, node);

      if (elem->is_node_on_side(node, s))
        for (unsigned int k = 0; k != dofs_here; ++k)
          di.push_back(first_dof + k);

      // Advance past this node's dofs whether or not we kept them
      first_dof += dofs_here;
    }
}



// Fills \p di with the local indices of the node-associated degrees of
// freedom that live on edge \p e of \p elem.  Local indices are
// assigned node by node, in node order; when \p add_p_level is true
// the element's p level is added to the order \p o.
template <unsigned int Dim, FEFamily T>
void FE<Dim,T>::dofs_on_edge(const Elem * const elem,
                             const Order o,
                             unsigned int e,
                             std::vector<unsigned int> & di,
                             const bool add_p_level)
{
  libmesh_assert(elem);
  libmesh_assert_less (e, elem->n_edges());

  di.clear();

  const Order total_order =
    static_cast<Order>(o + add_p_level*elem->p_level());
  const unsigned int node_count = elem->n_nodes();

  // Running local dof index; always points at the next unassigned dof
  unsigned int next_dof = 0;

  for (unsigned int node = 0; node != node_count; ++node)
    {
      const unsigned int dofs_here =
        n_dofs_at_node(*elem, total_order, node);

      if (!elem->is_node_on_edge(node, e))
        {
          // Not on the edge: skip over this node's dofs
          next_dof += dofs_here;
          continue;
        }

      for (unsigned int k = 0; k != dofs_here; ++k, ++next_dof)
        di.push_back(next_dof);
    }
}



// Records the node locations of \p elem and, for orientation-dependent
// families, the orientation flags of its edges and faces, so that
// matches_cache() can later compare another element against them.
template <unsigned int Dim, FEFamily T>
void FE<Dim,T>::cache(const Elem * elem)
{
  const unsigned int node_count = elem->n_nodes();
  cached_nodes.resize(node_count);
  for (unsigned int n = 0; n != node_count; ++n)
    cached_nodes[n] = elem->point(n);

  // Orientation data is only stored for families that depend on it
  if (!FEInterface::orientation_dependent(T))
    return;

  const unsigned int edge_count = elem->n_edges();
  cached_edges.resize(edge_count);
  for (unsigned int e = 0; e != edge_count; ++e)
    cached_edges[e] = elem->positive_edge_orientation(e);

  const unsigned int face_count = elem->n_faces();
  cached_faces.resize(face_count);
  for (unsigned int f = 0; f != face_count; ++f)
    cached_faces[f] = elem->positive_face_orientation(f);
}



// Returns true when \p elem agrees with the data stored by cache():
// same node count, same node positions relative to node 0 (compared
// with relative_fuzzy_equals), and, for orientation-dependent
// families, identical edge and face orientation flags.
template <unsigned int Dim, FEFamily T>
bool FE<Dim,T>::matches_cache(const Elem * elem)
{
  if (cached_nodes.size() != elem->n_nodes())
    return false;

  // Compare offsets from node 0 rather than absolute positions
  for (unsigned n = 1; n < elem->n_nodes(); n++)
    {
      const bool same_offset =
        (elem->point(n) - elem->point(0)).relative_fuzzy_equals(cached_nodes[n] - cached_nodes[0]);
      if (!same_offset)
        return false;
    }

  if (!FEInterface::orientation_dependent(T))
    return true;

  if (cached_edges.size() != elem->n_edges())
    return false;

  for (unsigned e = 0; e < elem->n_edges(); e++)
    if (elem->positive_edge_orientation(e) != cached_edges[e])
      return false;

  if (cached_faces.size() != elem->n_faces())
    return false;

  for (unsigned f = 0; f < elem->n_faces(); f++)
    if (elem->positive_face_orientation(f) != cached_faces[f])
      return false;

  return true;
}



// Reinitializes the shape function and map data for \p elem.
//
// \param elem    The element to compute on; may be null (see below).
// \param pts     Optional user-specified evaluation points.  When null,
//                the points of the attached quadrature rule are used.
// \param weights Optional weights matching \p pts.  When \p pts is
//                given without weights, unit weights are used for the
//                map computation.
//
// With an element and no user points, the expensive shape function
// initialization and computation are skipped when the element still
// matches the previously cached node locations and orientations.
template <unsigned int Dim, FEFamily T>
void FE<Dim,T>::reinit(const Elem * elem,
                       const std::vector<Point> * const pts,
                       const std::vector<Real> * const weights)
{
  // We can be called with no element.  If we're evaluating SCALAR
  // dofs we'll still have work to do.
  // libmesh_assert(elem);

  // We're calculating now!  Time to determine what.
  this->determine_calculations();

  // Try to avoid calling init_shape_functions
  // even when shapes_need_reinit
  bool cached_elem_still_fits = false;

  // Most of the hard work happens when we have an actual element
  if (elem)
    {
      // Initialize the shape functions at the user-specified
      // points
      if (pts != nullptr)
        {
          // Set the type and p level for this element
          this->_elem = elem;
          this->_elem_type = elem->type();
          this->_elem_p_level = elem->p_level();
          this->_p_level = this->_add_p_level_in_reinit * elem->p_level();

          // Initialize the shape functions
          this->_fe_map->template init_reference_to_physical_map<Dim>
            (*pts, elem);
          this->init_shape_functions (*pts, elem);

          // The shape functions do not correspond to the qrule
          this->shapes_on_quadrature = false;
        }

      // If there are no user specified points, we use the
      // quadrature rule

      // update the type in accordance to the current cell
      // and reinit if the cell type has changed or (as in
      // the case of the hierarchics) the shape functions need
      // reinit, since they depend on the particular element shape
      else
        {
          libmesh_assert(this->qrule);
          this->qrule->init(*elem);

          // A rule whose points change per element invalidates any
          // shape functions previously computed on it
          if (this->qrule->shapes_need_reinit())
            this->shapes_on_quadrature = false;

          // We're not going to bother trying to cache nodal
          // points *and* weights for fancier mapping types.
          if (this->get_type() != elem->type()       ||
              (elem->runtime_topology() &&
               this->_elem != elem)                  ||
              this->_elem_p_level != elem->p_level() ||
              !this->shapes_on_quadrature            ||
              elem->mapping_type() != LAGRANGE_MAP)
            {
              // Set the type and p level for this element
              this->_elem = elem;
              this->_elem_type = elem->type();
              this->_elem_p_level = elem->p_level();
              this->_p_level = this->_add_p_level_in_reinit * elem->p_level();

              // Initialize the shape functions
              this->_fe_map->template init_reference_to_physical_map<Dim>
                (this->qrule->get_points(), elem);
              this->init_shape_functions (this->qrule->get_points(), elem);
            }
          else
            {
              this->_elem = elem;

              // Check if cached element's nodes, edge and face orientations still fit
              cached_elem_still_fits = this->matches_cache(elem);

              // Initialize the shape functions if needed
              if (this->shapes_need_reinit() && !cached_elem_still_fits)
                {
                  this->_fe_map->template init_reference_to_physical_map<Dim>
                    (this->qrule->get_points(), elem);
                  this->init_shape_functions (this->qrule->get_points(), elem);
                }
            }

          // Replace cached nodes, edge and face orientations if no longer fitting
          if (this->shapes_need_reinit() && !cached_elem_still_fits && *caching)
            this->cache(elem);

          // The shape functions correspond to the qrule
          this->shapes_on_quadrature = true;
        }
    }
  else // With no defined elem, there is no mapping or caching to
       // be done, and our "quadrature rule" is one point for nonlocal
       // (SCALAR) variables and zero points for local variables.
    {
      this->_elem = nullptr;
      this->_elem_type = INVALID_ELEM;
      this->_elem_p_level = 0;
      this->_p_level = 0;

      if (!pts)
        {
          if (T == SCALAR)
            {
              this->qrule->get_points() =
                std::vector<Point>(1,Point(0));

              this->qrule->get_weights() =
                std::vector<Real>(1,1);
            }
          else
            {
              this->qrule->get_points().clear();
              this->qrule->get_weights().clear();
            }

          this->init_shape_functions (this->qrule->get_points(), elem);
        }
      else
        this->init_shape_functions (*pts, elem);
    }

  // Compute the map for this element.
  if (pts != nullptr)
    {
      if (weights != nullptr)
        {
          this->_fe_map->compute_map (this->dim, *weights, elem, this->calculate_d2phi);
        }
      else
        {
          // No weights were supplied with the user points: use unit weights
          std::vector<Real> dummy_weights (pts->size(), 1.);
          this->_fe_map->compute_map (this->dim, dummy_weights, elem, this->calculate_d2phi);
        }
    }
  else
    {
      this->_fe_map->compute_map (this->dim, this->qrule->get_weights(), elem, this->calculate_d2phi);
    }

  // Compute the shape functions and the derivatives at all of the
  // quadrature points.  This is skipped when the element matched the
  // cache, in which case the previously computed values are kept.
  if (!cached_elem_still_fits)
    {
      if (pts != nullptr)
        this->compute_shape_functions (elem,*pts);
      else
        this->compute_shape_functions(elem,this->qrule->get_points());
      if (this->calculate_dual)
      {
        if (T != LAGRANGE)
          nonlagrange_dual_warning();
        // Check if we need to calculate the dual coefficients based on the default QRule
        // We keep the default dual coeff calculation for the initial stage of the simulation
        // and in the middle of the simulation when a customized QRule is not provided.
        // This is used in MOOSE mortar-based contact. Currently, we re-compute dual_coeff
        // for all the elements on the mortar segment mesh by setting `calculate_default_dual_coeff' = false
        // in MOOSE (in `Assembly::reinitDual`) and use the customized QRule for calculating the dual shape coefficients
        // This is to be improved in the future
        if (elem && this->calculate_default_dual_coeff)
          this->reinit_default_dual_shape_coeffs(elem);
        // The dual shape functions relies on the customized shape functions
        // and the coefficient matrix, \p dual_coeff
        this->compute_dual_shape_functions();
      }
    }
}

// Recomputes the dual shape function coefficients on \p elem by
// evaluating all primal shape functions at \p pts and passing those
// values, together with the weights \p JxW, to
// compute_dual_shape_coeffs().
template <unsigned int Dim, FEFamily T>
void FE<Dim,T>::reinit_dual_shape_coeffs(const Elem * elem,
                                         const std::vector<Point> & pts,
                                         const std::vector<Real> & JxW)
{
  // Remember the element, its type and its p level
  this->_elem = elem;
  this->_elem_type = elem->type();
  this->_elem_p_level = elem->p_level();
  this->_p_level = this->_add_p_level_in_reinit * elem->p_level();

  const unsigned int shape_count =
    this->n_dofs(elem, this->get_order());

  // One row per shape function, one column per evaluation point
  std::vector<std::vector<OutputShape>> shape_values
    (shape_count, std::vector<OutputShape>(pts.size()));

  all_shapes(elem, this->get_order(), pts, shape_values);
  this->compute_dual_shape_coeffs(JxW, shape_values);
}

// Computes the dual shape coefficients on \p elem using a Gauss rule
// of the default quadrature order for this FE's order and family, and
// then switches off further default-coefficient calculations.
template <unsigned int Dim, FEFamily T>
void FE<Dim,T>::reinit_default_dual_shape_coeffs (const Elem * elem)
{
  libmesh_assert(elem);

  // Build and initialize the default Gauss rule for this element
  FEType fe_type_for_rule(this->get_order(), T);
  QGauss gauss_rule(elem->dim(), fe_type_for_rule.default_quadrature_order());
  gauss_rule.init(*elem);

  // The shape values at the default points are only used to obtain
  // dual_coeff here; the actual dual_phi values are produced later
  // in compute_dual_shape_functions().
  this->reinit_dual_shape_coeffs(elem,
                                 gauss_rule.get_points(),
                                 gauss_rule.get_weights());

  // This can be expensive, so do not repeat it on later calls
  this->set_calculate_default_dual_coeff(false);
}


// Sizes the dual shape function tables to \p n_shapes rows of \p n_qp
// entries each.  Only the tables whose calculation is requested are
// touched; nothing happens unless dual calculations are enabled.
template <unsigned int Dim, FEFamily T>
void FE<Dim,T>::init_dual_shape_functions(const unsigned int n_shapes, const unsigned int n_qp)
{
  if (!this->calculate_dual)
    return;

  libmesh_assert_msg(this->calculate_phi,
                     "dual shape function calculation relies on "
                     "primal shape functions being calculated");

  // Give a table n_shapes rows, each holding n_qp entries
  auto size_table = [n_shapes, n_qp](auto & table)
  {
    table.resize(n_shapes);
    for (auto & row : table)
      row.resize(n_qp);
  };

  size_table(this->dual_phi);

  if (this->calculate_dphi)
    size_table(this->dual_dphi);

#ifdef LIBMESH_ENABLE_SECOND_DERIVATIVES
  if (this->calculate_d2phi)
    size_table(this->dual_d2phi);
#endif
}

// Sizes the shape function data tables for the points \p qp on
// \p elem and fills the reference-space derivative tables.
//
// The tables for each requested quantity (phi, dphi, reference
// derivatives, curl, div, second derivatives) are resized to
// [number of shape functions][number of points].  Reference first
// derivatives are then evaluated through all_shape_derivs(), and
// reference second derivatives through shape_second_deriv().  Finally
// the dual shape function tables are sized if dual calculations are
// requested.
template <unsigned int Dim, FEFamily T>
void FE<Dim,T>::init_shape_functions(const std::vector<Point> & qp,
                                     const Elem * elem)
{
  // Start logging the shape function initialization
  LOG_SCOPE("init_shape_functions()", "FE");

  // The number of quadrature points.
  const unsigned int n_qp = cast_int<unsigned int>(qp.size());
  this->_n_total_qp = n_qp;

  // Number of shape functions in the finite element approximation
  // space.
  const unsigned int n_approx_shape_functions =
    this->n_dofs(elem, this->get_order());

  // Maybe we already have correctly-sized data?  Check data sizes,
  // and get ready to break out of a "loop" if all these resize()
  // calls are redundant.
  //
  // The first requested quantity whose outer size already matches
  // triggers the break; its inner size is recorded in old_n_qp so
  // that the per-shape resizes below can be skipped as well when the
  // number of points is unchanged.
  unsigned int old_n_qp = 0;
  do
    {
      // resize the vectors to hold current data
      // Phi are the shape functions used for the FE approximation
      // Phi_map are the shape functions used for the FE mapping
      if (this->calculate_phi)
        {
          if (this->phi.size() == n_approx_shape_functions)
            {
              old_n_qp = n_approx_shape_functions ? this->phi[0].size() : 0;
              break;
            }
          this->phi.resize     (n_approx_shape_functions);
        }
      if (this->calculate_dphi)
        {
          if (this->dphi.size() == n_approx_shape_functions)
            {
              old_n_qp = n_approx_shape_functions ? this->dphi[0].size() : 0;
              break;
            }
          this->dphi.resize    (n_approx_shape_functions);
          this->dphidx.resize  (n_approx_shape_functions);
          this->dphidy.resize  (n_approx_shape_functions);
          this->dphidz.resize  (n_approx_shape_functions);
        }

      if (this->calculate_dphiref)
        {
          if (Dim > 0)
            {
              if (this->dphidxi.size() == n_approx_shape_functions)
                {
                  old_n_qp = n_approx_shape_functions ? this->dphidxi[0].size() : 0;
                  break;
                }
              this->dphidxi.resize (n_approx_shape_functions);
            }

          if (Dim > 1)
            this->dphideta.resize      (n_approx_shape_functions);

          if (Dim > 2)
            this->dphidzeta.resize     (n_approx_shape_functions);
        }

      if (this->calculate_curl_phi && (FEInterface::field_type(T) == TYPE_VECTOR))
        this->curl_phi.resize(n_approx_shape_functions);

      if (this->calculate_div_phi && (FEInterface::field_type(T) == TYPE_VECTOR))
        this->div_phi.resize(n_approx_shape_functions);

#ifdef LIBMESH_ENABLE_SECOND_DERIVATIVES
      if (this->calculate_d2phi)
        {
          if (this->d2phi.size() == n_approx_shape_functions)
            {
              old_n_qp = n_approx_shape_functions ? this->d2phi[0].size() : 0;
              break;
            }

          this->d2phi.resize     (n_approx_shape_functions);
          this->d2phidx2.resize  (n_approx_shape_functions);
          this->d2phidxdy.resize (n_approx_shape_functions);
          this->d2phidxdz.resize (n_approx_shape_functions);
          this->d2phidy2.resize  (n_approx_shape_functions);
          this->d2phidydz.resize (n_approx_shape_functions);
          this->d2phidz2.resize  (n_approx_shape_functions);

          if (Dim > 0)
            this->d2phidxi2.resize (n_approx_shape_functions);

          if (Dim > 1)
            {
              this->d2phidxideta.resize (n_approx_shape_functions);
              this->d2phideta2.resize   (n_approx_shape_functions);
            }
          if (Dim > 2)
            {
              this->d2phidxidzeta.resize  (n_approx_shape_functions);
              this->d2phidetadzeta.resize (n_approx_shape_functions);
              this->d2phidzeta2.resize    (n_approx_shape_functions);
            }
        }
#endif // ifdef LIBMESH_ENABLE_SECOND_DERIVATIVES
    }
  while (false);

  // Resize the per-shape-function rows only when the number of
  // points differs from what the existing data was sized for.
  if (old_n_qp != n_qp)
    for (unsigned int i=0; i<n_approx_shape_functions; i++)
      {
        if (this->calculate_phi)
          this->phi[i].resize         (n_qp);

        if (this->calculate_dphi)
          {
            this->dphi[i].resize        (n_qp);
            this->dphidx[i].resize      (n_qp);
            this->dphidy[i].resize      (n_qp);
            this->dphidz[i].resize      (n_qp);
          }

        if (this->calculate_dphiref)
          {
            if (Dim > 0)
              this->dphidxi[i].resize(n_qp);

            if (Dim > 1)
              this->dphideta[i].resize(n_qp);

            if (Dim > 2)
              this->dphidzeta[i].resize(n_qp);
          }

        if (this->calculate_curl_phi && (FEInterface::field_type(T) == TYPE_VECTOR))
          this->curl_phi[i].resize(n_qp);

        if (this->calculate_div_phi && (FEInterface::field_type(T) == TYPE_VECTOR))
          this->div_phi[i].resize(n_qp);

#ifdef LIBMESH_ENABLE_SECOND_DERIVATIVES
        if (this->calculate_d2phi)
          {
            this->d2phi[i].resize     (n_qp);
            this->d2phidx2[i].resize  (n_qp);
            this->d2phidxdy[i].resize (n_qp);
            this->d2phidxdz[i].resize (n_qp);
            this->d2phidy2[i].resize  (n_qp);
            this->d2phidydz[i].resize (n_qp);
            this->d2phidz2[i].resize  (n_qp);
            if (Dim > 0)
              this->d2phidxi2[i].resize (n_qp);
            if (Dim > 1)
              {
                this->d2phidxideta[i].resize (n_qp);
                this->d2phideta2[i].resize   (n_qp);
              }
            if (Dim > 2)
              {
                this->d2phidxidzeta[i].resize  (n_qp);
                this->d2phidetadzeta[i].resize (n_qp);
                this->d2phidzeta2[i].resize    (n_qp);
              }
          }
#endif // ifdef LIBMESH_ENABLE_SECOND_DERIVATIVES
      }


#ifdef LIBMESH_ENABLE_INFINITE_ELEMENTS
  //------------------------------------------------------------
  // Initialize the data fields, which should only be used for infinite
  // elements, to some sensible values, so that using a FE with the
  // variational formulation of an InfFE, correct element matrices are
  // returned

  {
    if (this->calculate_phi || this->calculate_dphi)
      {
        this->weight.resize  (n_qp);
        for (unsigned int p=0; p<n_qp; p++)
          this->weight[p] = 1.;
      }

    if (this->calculate_dphi)
      {
        this->dweight.resize (n_qp);
        this->dphase.resize  (n_qp);
        for (unsigned int p=0; p<n_qp; p++)
          {
            this->dweight[p].zero();
            this->dphase[p].zero();
          }
      }
  }
#endif // ifdef LIBMESH_ENABLE_INFINITE_ELEMENTS

  // Compute the values of the shape function derivatives
  // in reference space, one table per reference direction
  if (this->calculate_dphiref && Dim > 0)
    {
      std::vector<std::vector<OutputShape>> * comps[3]
        { &this->dphidxi, &this->dphideta, &this->dphidzeta };
      FE<Dim,T>::all_shape_derivs(elem, this->fe_type.order, qp, comps, this->_add_p_level_in_reinit);
    }

  // Reference-space second derivatives.  The integer passed as the
  // fourth argument of shape_second_deriv selects the component:
  // 0 = xi^2, 1 = xi eta, 2 = eta^2, 3 = xi zeta, 4 = eta zeta,
  // 5 = zeta^2, matching the member each result is stored in.
  switch (Dim)
    {

      //------------------------------------------------------------
      // 0D
    case 0:
      {
        break;
      }

      //------------------------------------------------------------
      // 1D
    case 1:
      {
#ifdef LIBMESH_ENABLE_SECOND_DERIVATIVES
        // Compute the value of shape function i Hessians at quadrature point p
        if (this->calculate_d2phi)
          for (unsigned int i=0; i<n_approx_shape_functions; i++)
            for (unsigned int p=0; p<n_qp; p++)
              this->d2phidxi2[i][p] = FE<Dim, T>::shape_second_deriv(
                  elem, this->fe_type.order, i, 0, qp[p], this->_add_p_level_in_reinit);
#endif // ifdef LIBMESH_ENABLE_SECOND_DERIVATIVES

        break;
      }



      //------------------------------------------------------------
      // 2D
    case 2:
      {
#ifdef LIBMESH_ENABLE_SECOND_DERIVATIVES
        // Compute the value of shape function i Hessians at quadrature point p
        if (this->calculate_d2phi)
          for (unsigned int i=0; i<n_approx_shape_functions; i++)
            for (unsigned int p=0; p<n_qp; p++)
              {
                this->d2phidxi2[i][p] = FE<Dim, T>::shape_second_deriv(
                    elem, this->fe_type.order, i, 0, qp[p], this->_add_p_level_in_reinit);
                this->d2phidxideta[i][p] = FE<Dim, T>::shape_second_deriv(
                    elem, this->fe_type.order, i, 1, qp[p], this->_add_p_level_in_reinit);
                this->d2phideta2[i][p] = FE<Dim, T>::shape_second_deriv(
                    elem, this->fe_type.order, i, 2, qp[p], this->_add_p_level_in_reinit);
              }
#endif // ifdef LIBMESH_ENABLE_SECOND_DERIVATIVES


        break;
      }



      //------------------------------------------------------------
      // 3D
    case 3:
      {
#ifdef LIBMESH_ENABLE_SECOND_DERIVATIVES
        // Compute the value of shape function i Hessians at quadrature point p
        if (this->calculate_d2phi)
          for (unsigned int i=0; i<n_approx_shape_functions; i++)
            for (unsigned int p=0; p<n_qp; p++)
              {
                this->d2phidxi2[i][p] = FE<Dim, T>::shape_second_deriv(
                    elem, this->fe_type.order, i, 0, qp[p], this->_add_p_level_in_reinit);
                this->d2phidxideta[i][p] = FE<Dim, T>::shape_second_deriv(
                    elem, this->fe_type.order, i, 1, qp[p], this->_add_p_level_in_reinit);
                this->d2phideta2[i][p] = FE<Dim, T>::shape_second_deriv(
                    elem, this->fe_type.order, i, 2, qp[p], this->_add_p_level_in_reinit);
                this->d2phidxidzeta[i][p] = FE<Dim, T>::shape_second_deriv(
                    elem, this->fe_type.order, i, 3, qp[p], this->_add_p_level_in_reinit);
                this->d2phidetadzeta[i][p] = FE<Dim, T>::shape_second_deriv(
                    elem, this->fe_type.order, i, 4, qp[p], this->_add_p_level_in_reinit);
                this->d2phidzeta2[i][p] = FE<Dim, T>::shape_second_deriv(
                    elem, this->fe_type.order, i, 5, qp[p], this->_add_p_level_in_reinit);
              }
#endif // ifdef LIBMESH_ENABLE_SECOND_DERIVATIVES

        break;
      }


    default:
      libmesh_error_msg("Invalid dimension Dim = " << Dim);
    }

  // Size the dual shape function tables to match
  if (this->calculate_dual)
    this->init_dual_shape_functions(n_approx_shape_functions, n_qp);
}

// Generic implementation of all_shape_derivs(): for each of the Dim
// reference directions, fills every row of the corresponding table in
// \p comps by calling shape_derivs() once per shape function.  The
// number of shape functions is taken from the size of each table.
template <unsigned int Dim, FEFamily T>
void
FE<Dim,T>::default_all_shape_derivs (const Elem * elem,
                                     const Order o,
                                     const std::vector<Point> & p,
                                     std::vector<std::vector<OutputShape>> * comps[3],
                                     const bool add_p_level)
{
  for (unsigned int dir = 0; dir != Dim; ++dir)
    {
      std::vector<std::vector<OutputShape>> & table = *comps[dir];

      // Row index doubles as the shape function index
      unsigned int shape_idx = 0;
      for (auto & row : table)
        {
          FE<Dim,T>::shape_derivs
            (elem, o, shape_idx, dir, p, row, add_p_level);
          ++shape_idx;
        }
    }
}


// Generic implementation of side_nodal_soln(): computes the nodal
// solution on the whole element with nodal_soln() and then copies out
// the entries belonging to the nodes of side \p side, in the order
// returned by Elem::nodes_on_side().
template <unsigned int Dim, FEFamily T>
void
FE<Dim,T>::default_side_nodal_soln(const Elem * elem, const Order o,
                                   const unsigned int side,
                                   const std::vector<Number> & elem_soln,
                                   std::vector<Number> & nodal_soln_on_side,
                                   const bool add_p_level,
                                   const unsigned vdim)
{
  // Nodal values for every node of the element
  std::vector<Number> whole_elem_soln;
  nodal_soln(elem, o, elem_soln, whole_elem_soln, add_p_level, vdim);

  const std::vector<unsigned int> nodes_of_side =
    elem->nodes_on_side(side);

  nodal_soln_on_side.resize(nodes_of_side.size());

  // Pick out the side's entries
  auto out = nodal_soln_on_side.begin();
  for (const unsigned int node : nodes_of_side)
    {
      *out = whole_elem_soln[node];
      ++out;
    }
}




#ifdef LIBMESH_ENABLE_INFINITE_ELEMENTS

// Records \p e as the current element, initializes the
// reference-to-physical map for the points \p qp, and then
// initializes the shape functions at those points.
template <unsigned int Dim, FEFamily T>
void FE<Dim,T>::init_base_shape_functions(const std::vector<Point> & qp,
                                          const Elem * e)
{
  this->_elem_type = e->type();
  this->_elem = e;

  this->_fe_map->template init_reference_to_physical_map<Dim>(qp, e);

  this->init_shape_functions(qp, e);
}

#endif // LIBMESH_ENABLE_INFINITE_ELEMENTS


// Helper for FDM methods
namespace {
using namespace libMesh;

std::tuple<Point, Point, Real>
fdm_points(const unsigned int j, const Point & p)
{
  libmesh_assert_less (j, LIBMESH_DIM);

  // cheat by using finite difference approximations:
  const Real eps = 1.e-6;
  Point pp = p, pm = p;

  switch (j)
    {
      // d()/dxi
    case 0:
      {
        pp(0) += eps;
        pm(0) -= eps;
        break;
      }

      // d()/deta
    case 1:
      {
        pp(1) += eps;
        pm(1) -= eps;
        break;
      }

      // d()/dzeta
    case 2:
      {
        pp(2) += eps;
        pm(2) -= eps;
        break;
      }

    default:
      libmesh_error_msg("Invalid derivative index j = " << j);
    }

  return std::make_tuple(pp, pm, eps);
}

std::tuple<Point, Point, Real, unsigned int>
fdm_second_points(const unsigned int j, const Point & p)
{
  // cheat by using finite difference approximations:
  const Real eps = 1.e-5;
  Point pp = p, pm = p;
  unsigned int deriv_j = 0;

  switch (j)
    {
      //  d^2() / dxi^2
    case 0:
      {
        pp(0) += eps;
        pm(0) -= eps;
        deriv_j = 0;
        break;
      }

      // d^2() / dxi deta
    case 1:
      {
        pp(1) += eps;
        pm(1) -= eps;
        deriv_j = 0;
        break;
      }

      // d^2() / deta^2
    case 2:
      {
        pp(1) += eps;
        pm(1) -= eps;
        deriv_j = 1;
        break;
      }

      // d^2()/dxidzeta
    case 3:
      {
        pp(2) += eps;
        pm(2) -= eps;
        deriv_j = 0;
        break;
      }                  // d^2()/deta^2

      // d^2()/detadzeta
    case 4:
      {
        pp(2) += eps;
        pm(2) -= eps;
        deriv_j = 1;
        break;
      }

      // d^2()/dzeta^2
    case 5:
      {
        pp(2) += eps;
        pm(2) -= eps;
        deriv_j = 2;
        break;
      }

    default:
      libmesh_error_msg("Invalid shape function derivative j = " << j);
    }

  return std::make_tuple(pp, pm, eps, deriv_j);
}

}


// Central finite-difference approximation of the derivative of shape
// function \p i in reference direction \p j at \p p, for shape
// functions evaluated through an (Elem, Order, i, Point, add_p_level)
// callback.
template <typename OutputShape>
OutputShape fe_fdm_deriv(const Elem * elem,
                         const Order order,
                         const unsigned int i,
                         const unsigned int j,
                         const Point & p,
                         const bool add_p_level,
                         OutputShape(*shape_func)
                           (const Elem *, const Order,
                            const unsigned int, const Point &,
                            const bool))
{
  libmesh_assert(elem);

  const auto [p_plus, p_minus, h] = fdm_points(j, p);

  const OutputShape f_plus  = shape_func(elem, order, i, p_plus, add_p_level);
  const OutputShape f_minus = shape_func(elem, order, i, p_minus, add_p_level);

  return (f_plus - f_minus)/2./h;
}


// Central finite-difference approximation of the derivative of shape
// function \p i in reference direction \p j at \p p, for shape
// functions evaluated through an (ElemType, Order, i, Point) callback.
template <typename OutputShape>
OutputShape fe_fdm_deriv(const ElemType type,
                         const Order order,
                         const unsigned int i,
                         const unsigned int j,
                         const Point & p,
                         OutputShape(*shape_func)
                           (const ElemType, const Order,
                            const unsigned int, const Point &))
{
  const auto [p_plus, p_minus, h] = fdm_points(j, p);

  const OutputShape f_plus  = shape_func(type, order, i, p_plus);
  const OutputShape f_minus = shape_func(type, order, i, p_minus);

  return (f_plus - f_minus)/2./h;
}


// Central finite-difference approximation of the derivative of shape
// function \p i in reference direction \p j at \p p, for shape
// functions evaluated through an (ElemType, Order, Elem, i, Point)
// callback.
template <typename OutputShape>
OutputShape fe_fdm_deriv(const ElemType type,
                         const Order order,
                         const Elem * elem,
                         const unsigned int i,
                         const unsigned int j,
                         const Point & p,
                         OutputShape(*shape_func)
                           (const ElemType, const Order,
                            const Elem *,
                            const unsigned int, const Point &))
{
  const auto [p_plus, p_minus, h] = fdm_points(j, p);

  const OutputShape f_plus  = shape_func(type, order, elem, i, p_plus);
  const OutputShape f_minus = shape_func(type, order, elem, i, p_minus);

  return (f_plus - f_minus)/2./h;
}


// Central finite-difference approximation of second derivative
// component \p j of shape function \p i at \p p, obtained by
// differencing the first derivative callback \p deriv_func
// (Elem, Order, i, direction, Point, add_p_level).
template <typename OutputShape>
OutputShape
fe_fdm_second_deriv(const Elem * elem,
                    const Order order,
                    const unsigned int i,
                    const unsigned int j,
                    const Point & p,
                    const bool add_p_level,
                    OutputShape(*deriv_func)
                      (const Elem *, const Order,
                       const unsigned int, const unsigned int,
                       const Point &, const bool))
{
  const auto [p_plus, p_minus, h, first_dir] = fdm_second_points(j, p);

  const OutputShape d_plus  = deriv_func(elem, order, i, first_dir, p_plus, add_p_level);
  const OutputShape d_minus = deriv_func(elem, order, i, first_dir, p_minus, add_p_level);

  return (d_plus - d_minus)/2./h;
}


// Finite-difference approximation of second derivative number j of
// shape function i, for first-derivative routines that are looked up
// by (ElemType, Order) alone.
//
// fdm_second_points() supplies the two sample points, the step size,
// and the first-derivative direction that is passed on to deriv_func;
// the difference of the two first derivatives is divided by twice the
// step.
// NOTE(review): the index mapping lives in fdm_second_points(),
// defined elsewhere in this file -- confirm there.
template <typename OutputShape>
OutputShape
fe_fdm_second_deriv(const ElemType type,
                    const Order order,
                    const unsigned int i,
                    const unsigned int j,
                    const Point & p,
                    OutputShape(*deriv_func)
                      (const ElemType, const Order,
                       const unsigned int, const unsigned int,
                       const Point &))
{
  auto [p_plus, p_minus, step, first_dir] = fdm_second_points(j, p);

  const OutputShape deriv_plus =
    deriv_func(type, order, i, first_dir, p_plus);
  const OutputShape deriv_minus =
    deriv_func(type, order, i, first_dir, p_minus);

  return (deriv_plus - deriv_minus)/2./step;
}


// Finite-difference approximation of second derivative number j of
// shape function i, for first-derivative routines that take an
// explicit ElemType and Order in addition to the element pointer.
//
// fdm_second_points() supplies the two sample points, the step size,
// and the first-derivative direction that is passed on to deriv_func;
// the difference of the two first derivatives is divided by twice the
// step.
// NOTE(review): the index mapping lives in fdm_second_points(),
// defined elsewhere in this file -- confirm there.
template <typename OutputShape>
OutputShape
fe_fdm_second_deriv(const ElemType type,
                    const Order order,
                    const Elem * elem,
                    const unsigned int i,
                    const unsigned int j,
                    const Point & p,
                    OutputShape(*deriv_func)
                      (const ElemType, const Order,
                       const Elem *,
                       const unsigned int, const unsigned int,
                       const Point &))
{
  auto [p_plus, p_minus, step, first_dir] = fdm_second_points(j, p);

  const OutputShape deriv_plus =
    deriv_func(type, order, elem, i, first_dir, p_plus);
  const OutputShape deriv_minus =
    deriv_func(type, order, elem, i, first_dir, p_minus);

  return (deriv_plus - deriv_minus)/2./step;
}


void rational_fe_weighted_shapes(const Elem * elem,
                                 const FEType underlying_fe_type,
                                 std::vector<std::vector<Real>> & shapes,
                                 const std::vector<Point> & p,
                                 const bool add_p_level)
{
  const int extra_order = add_p_level * elem->p_level();

  const int dim = elem->dim();

  const unsigned int n_sf =
    FEInterface::n_shape_functions(underlying_fe_type, extra_order,
                                   elem);

  libmesh_assert_equal_to (n_sf, elem->n_nodes());

  std::vector<Real> node_weights(n_sf);

  const unsigned char datum_index = elem->mapping_data();
  for (unsigned int n=0; n<n_sf; n++)
    node_weights[n] =
      elem->node_ref(n).get_extra_datum<Real>(datum_index);

  const std::size_t n_p = p.size();

  shapes.resize(n_sf);
  for (unsigned int i=0; i != n_sf; ++i)
    {
      auto & shapes_i = shapes[i];
      shapes_i.resize(n_p, 0);
      FEInterface::shapes(dim, underlying_fe_type, elem, i, p,
                          shapes_i, add_p_level);
      for (auto & s : shapes_i)
        s *= node_weights[i];
    }
}


void rational_fe_weighted_shapes_derivs(const Elem * elem,
                                        const FEType fe_type,
                                        std::vector<std::vector<Real>> & shapes,
                                        std::vector<std::vector<std::vector<Real>>> & derivs,
                                        const std::vector<Point> & p,
                                        const bool add_p_level)
{
  const int extra_order = add_p_level * elem->p_level();
  const unsigned int dim = elem->dim();

  const unsigned int n_sf =
    FEInterface::n_shape_functions(fe_type, extra_order, elem);

  libmesh_assert_equal_to (n_sf, elem->n_nodes());

  libmesh_assert_equal_to (dim, derivs.size());
  for (unsigned int d = 0; d != dim; ++d)
    derivs[d].resize(n_sf);

  std::vector<Real> node_weights(n_sf);

  const unsigned char datum_index = elem->mapping_data();
  for (unsigned int n=0; n<n_sf; n++)
    node_weights[n] =
      elem->node_ref(n).get_extra_datum<Real>(datum_index);

  const std::size_t n_p = p.size();

  shapes.resize(n_sf);
  for (unsigned int i=0; i != n_sf; ++i)
    shapes[i].resize(n_p, 0);

  FEInterface::all_shapes(dim, fe_type, elem, p, shapes, add_p_level);

  for (unsigned int i=0; i != n_sf; ++i)
    {
      auto & shapes_i = shapes[i];

      for (auto & s : shapes_i)
        s *= node_weights[i];

      for (unsigned int d = 0; d != dim; ++d)
        {
          auto & derivs_di = derivs[d][i];
          derivs_di.resize(n_p);
          FEInterface::shape_derivs(fe_type, elem, i, d, p,
                                    derivs_di, add_p_level);
          for (auto & dip : derivs_di)
            dip *= node_weights[i];
        }
    }
}


// Evaluates rational shape function i at the point p:
//
//   R_i(p) = w_i * phi_i(p) / sum_k( w_k * phi_k(p) )
//
// where phi_k are the shape functions of underlying_fe_type and w_k is
// the weight stored on node k as the extra datum at index
// elem.mapping_data().
//
// \param elem               Element to evaluate on
// \param underlying_fe_type FE type supplying the non-rational basis
// \param i                  Index of the requested shape function
// \param p                  Evaluation point
// \param add_p_level        Whether elem.p_level() raises the order
// \returns The value of R_i at p
Real rational_fe_shape(const Elem & elem,
                       const FEType underlying_fe_type,
                       const unsigned int i,
                       const Point & p,
                       const bool add_p_level)
{
  const int extra_order = add_p_level * elem.p_level();

  const unsigned int n_sf =
    FEInterface::n_shape_functions(underlying_fe_type, extra_order, &elem);

  // One weight per node is paired with one shape function per node
  libmesh_assert_equal_to (n_sf, elem.n_nodes());

  const unsigned char datum_index = elem.mapping_data();

  // Numerator (function i only) and denominator (all functions).
  // Each weight is used exactly once, so it is read inside the loop
  // rather than being gathered into a temporary vector first.
  Real weighted_shape_i = 0, weighted_sum = 0;

  for (unsigned int sf=0; sf<n_sf; sf++)
    {
      const Real node_weight =
        elem.node_ref(sf).get_extra_datum<Real>(datum_index);
      const Real weighted_shape = node_weight *
        FEInterface::shape(underlying_fe_type, extra_order, &elem, sf, p);
      weighted_sum += weighted_shape;
      if (sf == i)
        weighted_shape_i = weighted_shape;
    }

  return weighted_shape_i / weighted_sum;
}


// Evaluates the derivative, in local direction j, of rational shape
// function i at the point p.
//
// With N = w_i*phi_i and W = sum_k(w_k*phi_k), the quotient rule gives
//
//   d(N/W)/dj = (W * dN/dj - N * dW/dj) / W^2
//
// where phi_k are the shape functions of underlying_fe_type and w_k is
// the weight stored on node k as the extra datum at index
// elem.mapping_data().
Real rational_fe_shape_deriv(const Elem & elem,
                             const FEType underlying_fe_type,
                             const unsigned int i,
                             const unsigned int j,
                             const Point & p,
                             const bool add_p_level)
{
  libmesh_assert_less(j, elem.dim());

  int extra_order = add_p_level * elem.p_level();

  const unsigned int n_sf =
    FEInterface::n_shape_functions(underlying_fe_type, extra_order, &elem);

  // One weight per node is paired with one shape function per node
  const unsigned int n_nodes = elem.n_nodes();
  libmesh_assert_equal_to (n_sf, n_nodes);

  // Gather the per-node weights up front
  const unsigned char weight_index = elem.mapping_data();
  std::vector<Real> weights(n_nodes);
  for (unsigned int n=0; n != n_nodes; ++n)
    weights[n] = elem.node_ref(n).get_extra_datum<Real>(weight_index);

  // N, dN/dj for function i, and W, dW/dj summed over all functions
  Real numer = 0;
  Real numer_deriv = 0;
  Real denom = 0;
  Real denom_deriv = 0;

  for (unsigned int sf=0; sf != n_sf; ++sf)
    {
      const Real w = weights[sf];

      const Real term = w *
        FEInterface::shape(underlying_fe_type, extra_order, &elem, sf, p);
      const Real term_deriv = w *
        FEInterface::shape_deriv(underlying_fe_type, extra_order, &elem, sf, j, p);

      denom += term;
      denom_deriv += term_deriv;

      if (sf != i)
        continue;

      numer = term;
      numer_deriv = term_deriv;
    }

  return (denom * numer_deriv - numer * denom_deriv) /
         denom / denom;
}


#ifdef LIBMESH_ENABLE_SECOND_DERIVATIVES

// Evaluates second derivative number j of rational shape function i
// at the point p.
//
// The index j selects the pair of local directions (a, b):
//   0: (xi, xi)     1: (xi, eta)     2: (eta, eta)
//   3: (xi, zeta)   4: (eta, zeta)   5: (zeta, zeta)
// Any other value of j triggers libmesh_error().
//
// With N = w_i*phi_i and W = sum_k(w_k*phi_k), differentiating N/W
// twice gives
//
//   d2(N/W)/(da db) = ( W*N_ab - N_a*W_b - N_b*W_a - N*W_ab
//                       + 2*N*W_a*W_b/W ) / W^2
//
// where phi_k are the shape functions of underlying_fe_type and w_k is
// the weight stored on node k as the extra datum at index
// elem.mapping_data().
Real rational_fe_shape_second_deriv(const Elem & elem,
                                    const FEType underlying_fe_type,
                                    const unsigned int i,
                                    const unsigned int j,
                                    const Point & p,
                                    const bool add_p_level)
{
  unsigned int j1, j2;
  switch (j)
    {
    case 0:
      // j = 0 ==> d^2 phi / dxi^2
      j1 = j2 = 0;
      break;
    case 1:
      // j = 1 ==> d^2 phi / dxi deta
      j1 = 0;
      j2 = 1;
      break;
    case 2:
      // j = 2 ==> d^2 phi / deta^2
      j1 = j2 = 1;
      break;
    case 3:
      // j = 3 ==> d^2 phi / dxi dzeta
      j1 = 0;
      j2 = 2;
      break;
    case 4:
      // j = 4 ==> d^2 phi / deta dzeta
      j1 = 1;
      j2 = 2;
      break;
    case 5:
      // j = 5 ==> d^2 phi / dzeta^2
      j1 = j2 = 2;
      break;
    default:
      libmesh_error();
    }

  int extra_order = add_p_level * elem.p_level();

  const unsigned int n_sf =
    FEInterface::n_shape_functions(underlying_fe_type, extra_order,
                                   &elem);

  // One weight per node is paired with one shape function per node
  const unsigned int n_nodes = elem.n_nodes();
  libmesh_assert_equal_to (n_sf, n_nodes);

  std::vector<Real> node_weights(n_nodes);

  const unsigned char datum_index = elem.mapping_data();
  for (unsigned int n=0; n<n_nodes; n++)
    node_weights[n] =
      elem.node_ref(n).get_extra_datum<Real>(datum_index);

  // "_i" quantities belong to the requested function (N and its
  // derivatives); "_sum" quantities are summed over every function
  // (W and its derivatives).  "grada"/"gradb" are the first
  // derivatives in directions j1/j2.
  Real weighted_shape_i = 0, weighted_sum = 0,
       weighted_grada_i = 0, weighted_grada_sum = 0,
       weighted_gradb_i = 0, weighted_gradb_sum = 0,
       weighted_hess_i = 0, weighted_hess_sum = 0;

  for (unsigned int sf=0; sf<n_sf; sf++)
    {
      Real weighted_shape = node_weights[sf] *
        FEInterface::shape(underlying_fe_type, extra_order, &elem, sf,
                           p);
      Real weighted_grada = node_weights[sf] *
        FEInterface::shape_deriv(underlying_fe_type, extra_order,
                                 &elem, sf, j1, p);
      Real weighted_hess = node_weights[sf] *
        FEInterface::shape_second_deriv(underlying_fe_type,
                                        extra_order, &elem, sf, j, p);
      weighted_sum += weighted_shape;
      weighted_grada_sum += weighted_grada;

      // For a pure second derivative the two directions coincide, so
      // the j1 derivative is reused; for a mixed derivative the j2
      // derivative must be evaluated and accumulated separately.
      Real weighted_gradb = weighted_grada;
      if (j1 != j2)
        {
          weighted_gradb = node_weights[sf] *
            FEInterface::shape_deriv(underlying_fe_type, extra_order,
                                     &elem, sf, j2, p);
          weighted_gradb_sum += weighted_gradb;
        }
      weighted_hess_sum += weighted_hess;
      if (sf == i)
        {
          weighted_shape_i = weighted_shape;
          weighted_grada_i = weighted_grada;
          weighted_gradb_i = weighted_gradb;
          weighted_hess_i = weighted_hess;
        }
    }

  // Pure second derivative: the j2 sum was not accumulated in the
  // loop because it is identical to the j1 sum.
  if (j1 == j2)
    weighted_gradb_sum = weighted_grada_sum;

  return (weighted_sum * weighted_hess_i - weighted_grada_i * weighted_gradb_sum -
          weighted_shape_i * weighted_hess_sum - weighted_gradb_i * weighted_grada_sum +
          2 * weighted_grada_sum * weighted_shape_i * weighted_gradb_sum / weighted_sum) /
         weighted_sum / weighted_sum;
}

#endif // LIBMESH_ENABLE_SECOND_DERIVATIVES


// Evaluates rational shape functions at every point in p, writing
// v[i][q] = (weighted shape i at p[q]) / (sum of weighted shapes at
// p[q]).
//
// v is not resized here: both loops run over the existing extent of v,
// and each v[i] is asserted to already hold p.size() entries, so the
// caller must pre-size v.
void rational_all_shapes (const Elem & elem,
                          const FEType underlying_fe_type,
                          const std::vector<Point> & p,
                          std::vector<std::vector<Real>> & v,
                          const bool add_p_level)
{
  const std::size_t n_points = p.size();

  // Numerators: weight * underlying shape, for every function/point
  std::vector<std::vector<Real>> weighted;
  rational_fe_weighted_shapes(&elem, underlying_fe_type, weighted, p,
                              add_p_level);

  // Denominators: one sum of numerators per point
  std::vector<Real> denom(n_points, 0);

  for (auto i : index_range(v))
    {
      const auto & weighted_i = weighted[i];
      libmesh_assert_equal_to ( n_points, weighted_i.size() );
      for (auto q : index_range(p))
        denom[q] += weighted_i[q];
    }

  // Normalize
  for (auto i : index_range(v))
    {
      auto & v_i = v[i];
      libmesh_assert_equal_to ( n_points, v_i.size() );
      for (auto q : index_range(v_i))
        v_i[q] = weighted[i][q] / denom[q];
    }
}


// Evaluates the first derivatives of every rational shape function at
// every point in p.  For each local direction d below the element
// dimension, (*comps[d])[i][q] receives
//
//   (W * dN_i/dd - N_i * dW/dd) / W^2
//
// with N_i the weighted underlying shape i and W the sum of all N_k,
// both taken at p[q].
//
// The output tables are not resized here: *comps[d] is asserted to
// have one row per element node, and each row is filled over its
// existing length.
template <typename OutputShape>
void rational_all_shape_derivs (const Elem & elem,
                                const FEType underlying_fe_type,
                                const std::vector<Point> & p,
                                std::vector<std::vector<OutputShape>> * comps[3],
                                const bool add_p_level)
{
  const int my_dim = elem.dim();
  const std::size_t n_points = p.size();

  // Weighted shapes N_i and their directional derivatives
  std::vector<std::vector<Real>> weighted;
  std::vector<std::vector<std::vector<Real>>> weighted_derivs(my_dim);

  rational_fe_weighted_shapes_derivs(&elem, underlying_fe_type,
                                     weighted, weighted_derivs, p,
                                     add_p_level);

  // W and dW/dd at each point, zero-initialized before accumulation
  std::vector<Real> denom(n_points, 0);
  std::vector<std::vector<Real>>
    denom_derivs(my_dim, std::vector<Real>(n_points));

  for (auto i : index_range(weighted))
    {
      const auto & weighted_i = weighted[i];
      libmesh_assert_equal_to ( n_points, weighted_i.size() );

      for (auto q : index_range(p))
        denom[q] += weighted_i[q];

      for (int d=0; d != my_dim; ++d)
        {
          auto & denom_derivs_d = denom_derivs[d];
          const auto & weighted_derivs_di = weighted_derivs[d][i];
          for (auto q : index_range(p))
            denom_derivs_d[q] += weighted_derivs_di[q];
        }
    }

  // Quotient rule, one direction at a time
  for (int d=0; d != my_dim; ++d)
    {
      auto & out_d = *comps[d];
      libmesh_assert_equal_to(out_d.size(), elem.n_nodes());

      const auto & denom_derivs_d = denom_derivs[d];

      for (auto i : index_range(out_d))
        {
          auto & out_di = out_d[i];
          const auto & numer_i = weighted[i];
          const auto & numer_deriv_di = weighted_derivs[d][i];

          for (auto q : index_range(out_di))
            out_di[q] = (denom[q] * numer_deriv_di[q] -
              numer_i[q] * denom_derivs_d[q]) /
              denom[q] / denom[q];
        }
    }
}



// Explicit instantiations of the finite-difference helper templates
// defined above, one per (output type, shape function signature)
// combination listed here.

// fe_fdm_deriv, Real-valued: Elem-based signature with add_p_level
template
Real fe_fdm_deriv<Real>(const Elem *, const Order, const unsigned int,
                        const unsigned int, const Point &, const bool,
                        Real(*shape_func)
                          (const Elem *, const Order, const unsigned int,
                           const Point &, const bool));

// fe_fdm_deriv, Real-valued: ElemType-based signature
template
Real fe_fdm_deriv<Real>(const ElemType, const Order, const unsigned int,
                        const unsigned int, const Point &,
                        Real(*shape_func)
                          (const ElemType, const Order, const unsigned int,
                           const Point &));

// fe_fdm_deriv, Real-valued: ElemType + Elem signature
template
Real fe_fdm_deriv<Real>(const ElemType, const Order, const Elem *,
                        const unsigned int, const unsigned int, const Point &,
                        Real(*shape_func)
                          (const ElemType, const Order, const Elem *,
                           const unsigned int, const Point &));

// fe_fdm_deriv, RealGradient-valued: Elem-based signature with
// add_p_level
template
RealGradient
fe_fdm_deriv<RealGradient>(const Elem *, const Order, const unsigned int,
                           const unsigned int, const Point &, const bool,
                           RealGradient(*shape_func)
                             (const Elem *, const Order, const unsigned int,
                              const Point &, const bool));

// fe_fdm_second_deriv, Real-valued: ElemType-based signature
template
Real
fe_fdm_second_deriv<Real>(const ElemType, const Order, const unsigned int,
                          const unsigned int, const Point &,
                          Real(*shape_func)
                            (const ElemType, const Order, const unsigned int,
                             const unsigned int, const Point &));

// fe_fdm_second_deriv, Real-valued: Elem-based signature with
// add_p_level
template
Real
fe_fdm_second_deriv<Real>(const Elem *, const Order, const unsigned int,
                          const unsigned int, const Point &, const bool,
                          Real(*shape_func)
                            (const Elem *, const Order, const unsigned int,
                             const unsigned int, const Point &, const bool));

// fe_fdm_second_deriv, Real-valued: ElemType + Elem signature
template
Real
fe_fdm_second_deriv<Real>(const ElemType, const Order, const Elem *,
                          const unsigned int, const unsigned int, const Point &,
                          Real(*shape_func)
                            (const ElemType, const Order, const Elem *,
                             const unsigned int, const unsigned int, const Point &));

// fe_fdm_second_deriv, RealGradient-valued: Elem-based signature with
// add_p_level
template
RealGradient
fe_fdm_second_deriv<RealGradient>(const Elem *, const Order, const unsigned int,
                           const unsigned int, const Point &, const bool,
                           RealGradient(*shape_func)
                             (const Elem *, const Order, const unsigned int,
                              const unsigned int, const Point &, const bool));




//--------------------------------------------------------------
// Explicit instantiations using macro from fe_macro.h
//
// NOTE(review): the macro arguments 0 through 3 are presumably the
// spatial dimensions being instantiated -- confirm in fe_macro.h.

INSTANTIATE_FE(0);

INSTANTIATE_FE(1);

INSTANTIATE_FE(2);

INSTANTIATE_FE(3);

INSTANTIATE_SUBDIVISION_FE;

// Explicit instantiation of rational_all_shape_derivs for Real-valued
// output tables.
template LIBMESH_EXPORT void rational_all_shape_derivs<Real> (const Elem & elem, const FEType underlying_fe_type, const std::vector<Point> & p, std::vector<std::vector<Real>> * comps[3], const bool add_p_level);
} // namespace libMesh
